"""
Helpers for naming, saving and plotting CME simulation data with qutip

@author: Mikolaj Roguski
"""
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from qutip import qload, plot_fock_distribution, destroy, expect, tensor, qeye, destroy, qsave, coherent, num
import matplotlib.patches as mpatches
from scipy.optimize import curve_fit


# Colour palette shared by the plotting helpers below, as (name, colour) pairs.
# Most colours are RGB tuples normalised to [0, 1]; index 9 is a hex string.
# NOTE(review): the plotting functions index this list as
# default_colors[index + 5][1], so only 6 colours (indices 5-10) are reachable
# before an IndexError, and reordering entries changes plot colours.
# NOTE(review): '#96FE69' (index 9) is a light green despite its 'Yellow'
# name; the real yellow is the last entry. Confirm before renaming.
default_colors = [
    ('Blue', (0 / 255, 114 / 255, 178 / 255)),
    ('Reddish Purple', (204 / 255, 121 / 255, 167 / 255)),
    ('Vermillion', (213 / 255, 94 / 255, 0 / 255)),
    ('Bluish Green', (0 / 255, 158 / 255, 115 / 255)),
    ('Orange', (230 / 255, 159 / 255, 0 / 255)),
    ('Black', (0 / 255, 0 / 255, 0 / 255)),
    ('Sky Blue', (86 / 255, 180 / 255, 233 / 255)),
    ('Gray', (128 / 255, 128 / 255, 128 / 255)),
    ('Light Brown', (190 / 255, 120 / 255, 80 / 255)),
    ('Yellow', '#96FE69'),
    ('Yellow', (240 / 255, 228 / 255, 66 / 255))
]



def params_dict_to_filename(params_dict, skip_handles):
    """Build the simulation file stem ``sim_<key>_<value>_...`` from parameters.

    Parameters
    ----------
    params_dict : dict
        Maps parameter names to a ``(value, unit)`` tuple, a string, or a
        bare number.
    skip_handles : container of str
        Parameter names to leave out of the file name.

    Returns
    -------
    str
        ``"sim_"`` followed by ``key_value`` fragments joined by underscores.
        Numbers are written as rounded absolute integers (sign dropped).
        Tuple values are multiplied by 1e3 first when their key is one of
        the scaling keys, and their unit is appended.
    """
    # Keys whose values are written x1e3 in the file name (e.g. between kHz
    # and MHz); filename_to_params_dict divides the same keys by 1e3.
    thousand_keys = {'Eac0', 'dm0', 'vlattice', 'v1', 'Llattice'}

    def as_digits(number):
        # Rounded magnitude with any sign characters removed.
        return f"{abs(number):.0f}".replace("+", "").replace("-", "")

    pieces = []
    for name, entry in params_dict.items():
        if name in skip_handles:
            continue
        if isinstance(entry, tuple):
            magnitude, unit = entry
            factor = 1e3 if name in thousand_keys else 1.0
            pieces.append(f"{name}_{as_digits(magnitude * factor)}{unit}")
        elif isinstance(entry, str):
            pieces.append(f"{name}_{entry}")
        else:
            pieces.append(f"{name}_{as_digits(entry)}")
    return "sim_" + "_".join(pieces)
    

def filename_to_params_dict(filename, data_folder_name, skip_handles=None):
    """Parse a simulation file name back into a parameter dictionary.

    Inverse of ``params_dict_to_filename`` for files named
    ``<prefix>sim_<key>_<value>_....qu`` (the prefix is e.g. 'q1' or 'q2').

    Parameters
    ----------
    filename : str
        Path or bare file name, e.g. ``"data/q1sim_Nfock_10.qu"``.
    data_folder_name : str
        Folder the file lives in. Only the file's base name is parsed, so
        this is no longer needed; it is kept for backward compatibility.
    skip_handles : iterable of str, optional
        Parameter names to drop from the result.

    Returns
    -------
    dict
        Values matching ``<digits><letters>`` become ``(value, unit)`` tuples
        (``unit`` may be ``''``), divided by 1e3 for the scaling keys.
        ``OPphDist`` values not matching that pattern stay strings. Other
        values are converted with ``float`` and kept as strings if that fails.

    Raises
    ------
    ValueError
        If the parameter part has a key without a value.
    """
    # Keys that require scaling (e.g. between kHz and MHz); must match
    # params_dict_to_filename, which multiplies them by 1e3.
    scaling_keys = {'Eac0', 'dm0', 'vlattice', 'v1', 'Llattice'}

    # Work on the file's own name. Slicing off len(data_folder_name)+3
    # characters broke for folders with a trailing separator or an empty name.
    stem = os.path.basename(filename)
    if stem.endswith('.qu'):
        stem = stem[:-3]

    # Everything before the first "sim_" is the file prefix ('q1', 'q2', ...).
    # Only that first marker is removed, so values containing "sim_" survive.
    _, marker, tail = stem.partition("sim_")
    if marker:
        stem = tail
    else:
        # Legacy fallback: drop a two-character prefix and leading underscores.
        stem = stem[2:].lstrip("_")

    tokens = stem.split("_") if stem else []
    if len(tokens) % 2:
        raise ValueError(
            f"Cannot parse parameters from {filename!r}: "
            f"key {tokens[-1]!r} has no value"
        )

    # Integer value with an optional alphabetic unit suffix, e.g. '5kHz'.
    value_pattern = re.compile(r"(\d+)([a-zA-Z]*)$")

    params_dict = {}
    for key, value_str in zip(tokens[0::2], tokens[1::2]):
        match = value_pattern.match(value_str)
        if match:
            digits, unit = match.groups()
            divisor = 1e3 if key in scaling_keys else 1e0
            params_dict[key] = (float(digits) / divisor, unit)
        elif key == 'OPphDist':
            params_dict[key] = value_str
        else:
            try:
                params_dict[key] = float(value_str)
            except ValueError:
                # Free-form string values written by params_dict_to_filename.
                params_dict[key] = value_str

    for handle in (skip_handles or ()):
        params_dict.pop(handle, None)

    return params_dict


def check_file_existance(prefix, exp_params_dict, data_folder_name, skip_handles=None, auto_skipFlag=False,
                         auto_def_response='no'):
    """Check whether the data file for these parameters already exists.

    The path is ``data_folder_name/<prefix><params_dict_to_filename(...)>.qu``.
    If it exists, the user is asked whether to overwrite it, unless
    ``auto_skipFlag`` is set; in that case ``auto_def_response`` is used as
    the answer. Answers 'yes' or 'y' (case- and whitespace-insensitive) allow
    overwriting; anything else keeps the file.

    Returns
    -------
    (bool, str)
        ``(True, path)`` if the file exists and should NOT be overwritten.
        ``(False, path)`` if it does not exist or overwriting was approved.
    """
    if skip_handles is None:
        skip_handles = []

    filename = prefix + params_dict_to_filename(exp_params_dict, skip_handles) + ".qu"
    file_path = os.path.join(data_folder_name, filename)

    if not os.path.exists(file_path):
        return False, file_path

    print('(File:) The file already exists! See path:', file_path)

    if auto_skipFlag:
        # Non-interactive mode: use the caller-supplied default answer.
        response = auto_def_response
    else:
        response = input("Do you want to overwrite the file? (yes/no): ")

    # Previously a typed 'y' was silently treated as 'no'.
    overwrite = str(response).strip().lower() in ("yes", "y")
    return (not overwrite), file_path


def save_qutip_data(data, prefix, exp_params_dict, data_folder_name, skip_handles=None):
    """Save ``data`` with qutip's ``qsave`` under a parameter-derived name.

    The target is ``data_folder_name/<prefix><params_dict_to_filename(...)>``.
    The folder is created if missing; an empty folder name means the current
    directory. No extension is added here.
    NOTE(review): check_file_existance looks for a '.qu' suffix, presumably
    added by ``qsave``; confirm.

    Parameters
    ----------
    data : object
        Object to save (e.g. a qutip result).
    prefix : str
        File-name prefix such as 'q1' or 'q2'.
    exp_params_dict : dict
        Parameters encoded into the file name.
    data_folder_name : str
        Destination folder.
    skip_handles : list of str, optional
        Parameters excluded from the file name.
    """
    if skip_handles is None:
        skip_handles = []

    name_experiment_results = prefix + params_dict_to_filename(exp_params_dict, skip_handles)
    file_path = os.path.join(data_folder_name, name_experiment_results)

    # os.makedirs('') raises, so only create a real folder; exist_ok avoids
    # the race between an exists() check and makedirs().
    if data_folder_name:
        os.makedirs(data_folder_name, exist_ok=True)

    qsave(data, file_path)
    print(f"Data saved as {file_path}")
    
    
def create_plot_label(params_dict, legend_params):
    """Build a comma-separated legend/title label from selected parameters.

    Parameters
    ----------
    params_dict : dict
        Values are ``(value, unit)`` pairs, strings, or bare numbers (as
        returned by filename_to_params_dict).
    legend_params : iterable of str
        Names to include, in order. Names missing from ``params_dict`` are
        skipped.

    Returns
    -------
    str
        String values appear verbatim. Numeric values appear as
        ``name=<value><unit>``, with the value scaled by 1e3 for the same keys
        the file-name helpers scale, then rounded to an integer magnitude.
    """
    # Same scaling keys as params_dict_to_filename / filename_to_params_dict,
    # so labels show the same numbers as the file names.
    scaling_keys = {'Eac0', 'dm0', 'v1', 'vlattice', 'Llattice'}

    label_parts = []
    for param in legend_params:
        if param not in params_dict:
            continue
        entry = params_dict[param]
        if isinstance(entry, str):
            label_parts.append(entry)
            continue
        # Bare numbers (e.g. floats parsed from '1.5') carry no unit.
        value, unit = entry if isinstance(entry, (tuple, list)) else (entry, '')
        scaled_value = value * (1e3 if param in scaling_keys else 1e0)
        label_parts.append(f"{param}={abs(scaled_value):.0f}{unit}")

    return ", ".join(label_parts)


def _print_normalisation(state):
    """Print the total probability of ``state`` as a normalisation check.

    For a ket (or bra) this is sum |c_n|^2; for a density matrix it is the
    trace Tr(rho). Summing |rho_ij|^2 of a density matrix would instead give
    the purity Tr(rho^2), which is below 1 for mixed states.
    """
    matrix = state.full()
    if 1 in matrix.shape:
        total = np.sum(np.abs(matrix) ** 2)
    else:
        total = np.real(np.trace(matrix))
    print('(Running calculations:) The normalisation check:', total)


def _expectation_from_states(data, exp_params_dict):
    """Recompute the mean phonon number for every state in ``data.states``.

    The number operator is built from the Hilbert-space sizes stored in
    ``exp_params_dict``.

    Raises
    ------
    ValueError
        If none of 'NOP', 'Nfock' or 'nOPfixed' is present.
    """
    if 'NOP' in exp_params_dict:
        # Two-ion calculation on the NIP x NOP product space; the IP-mode
        # number operator is used.
        # NOTE(review): mode_plt is not consulted here, so the OP mode is
        # never recomputed; confirm this is intended.
        NIP = int(exp_params_dict['NIP'][0])
        NOP = int(exp_params_dict['NOP'][0])
        a_op = tensor(destroy(NIP), qeye(NOP))
    elif 'Nfock' in exp_params_dict:
        # Single-ion calculation.
        a_op = destroy(int(exp_params_dict['Nfock'][0]))
    elif 'nOPfixed' in exp_params_dict:
        # 1D two-ion calculation: operator on the NIP space only.
        NIP, _ = exp_params_dict["NIP"]
        a_op = destroy(int(NIP))
    else:
        raise ValueError(
            "No expectation values stored and none of 'NOP', 'Nfock' or "
            "'nOPfixed' in exp_params_dict to rebuild them from the states."
        )
    n_op = a_op.dag() * a_op
    return np.array([expect(n_op, state) for state in data.states])


def plot_expectation_vs_time(data, exp_params_dict, figure_save_folder_path,
                              savePlotFlag=False, newFigureFlag=False, skip_handles=None, legend_params=None,
                              timepoint_list=None, legend_str_explicit=None, mode_plt=0):
    """Plot the mean phonon number versus interaction time.

    Parameters
    ----------
    data : qutip result
        Must provide ``.states``, ``.times`` and ``.expect``.
    exp_params_dict : dict
        Simulation parameters. They are used for the legend label and the
        saved file name. If no expectation values are stored, they also
        determine the number operator.
    figure_save_folder_path : str or None
        Folder for the PNG when ``savePlotFlag`` is set.
    savePlotFlag : bool
        Save the figure as ``expectPlot_<params>.png``.
    newFigureFlag : bool
        Draw on a fresh figure instead of the shared 'expec_values' figure.
    skip_handles : list of str, optional
        Parameters excluded from the saved file name.
    legend_params : list of str, optional
        Parameters shown in the legend label.
    timepoint_list : list of int, optional
        Indices into ``data.times`` to highlight with markers.
    legend_str_explicit : str, optional
        Overrides the generated legend label.
    mode_plt : int
        Index into ``data.expect`` to plot.

    Raises
    ------
    ValueError
        If no expectation values are stored and the parameters do not
        identify a known Hilbert-space layout.
    """
    skip_handles = [] if skip_handles is None else skip_handles
    legend_params = [] if legend_params is None else legend_params
    timepoint_list = [] if timepoint_list is None else timepoint_list

    _print_normalisation(data.states[-1])

    # Keep a handle on the target figure for saving. Previously fig1 was
    # undefined when newFigureFlag was True, so saving raised NameError.
    if newFigureFlag:
        fig1 = plt.figure()
    else:
        fig1 = plt.figure(num='expec_values')  # reused if it already exists

    if legend_str_explicit:
        label_string = legend_str_explicit
    else:
        label_string = create_plot_label(exp_params_dict, legend_params)

    tlist = np.array(data.times)

    try:
        expectation_values = np.asarray(data.expect[mode_plt])
    except (IndexError, KeyError):
        # Presumably no e_ops were stored for this mode; rebuild from states.
        expectation_values = np.array([])

    if expectation_values.size == 0:
        # Rebuilding is also needed for concatenated files, since the
        # expectation values cannot be overwritten on the qutip result.
        expectation_values = _expectation_from_states(data, exp_params_dict)
        print('(Plotting fig.1) Calculated expectation values from states.')

    plt.plot(tlist, expectation_values, label=label_string, linewidth=1)
    fontsize_choice = 12
    plt.xlabel('Time [us]', fontsize=fontsize_choice)
    plt.ylabel(r'$\bar{n}_-$', fontsize=fontsize_choice)

    # Mark the requested time points in the colours used by the phonon plots.
    for index, timepoint in enumerate(timepoint_list):
        plt.plot(tlist[timepoint], expectation_values[timepoint], 'o',
                 color=default_colors[index + 5][1], markersize=4, label='_nolegend_')

    if savePlotFlag and figure_save_folder_path is not None:
        filename = params_dict_to_filename(exp_params_dict, skip_handles)
        fig_name = os.path.join(figure_save_folder_path, 'expectPlot_' + filename + '.png')
        fig1.savefig(fig_name)
        print(f"(Plotting fig. 1) Plot saved as: {fig_name}")

    # Legend is added after saving on purpose: the saved file has no legend.
    plt.legend()
    plt.draw()
    
      
    
def plot_phonon_distribution(data, exp_params_dict, figure_save_folder_path,
                             timepoint_list=None, iterator=0, savePlotFlag=False,
                             newFigureFlag=False, skip_handles=None, title_params=None,
                             fitCoherentFlag=False, fitGaussianFlag=False, mode_plt=0):
    """Plot the Fock-state distribution of one mode at selected time indices.

    Parameters
    ----------
    data : qutip result
        Must provide ``.states`` and ``.times``.
    exp_params_dict : dict
        Simulation parameters, used for the title and the saved file name.
    figure_save_folder_path : str or None
        Folder for the PNG when ``savePlotFlag`` is set.
    timepoint_list : list of int, optional
        Indices into ``data.states``; defaults to ``[-1]`` (final state).
        With one index, the bar colour is picked by ``iterator``. With
        several, each gets its own colour and a legend entry with its time.
    iterator : int
        Colour offset used with a single time point, so that colours match
        other plots of the same setting.
    savePlotFlag, newFigureFlag : bool
        Save the figure / draw on a fresh figure instead of 'phonon_values'.
    skip_handles : list of str, optional
        Parameters excluded from the saved file name.
    title_params : list of str, optional
        Parameters shown in the title.
    fitCoherentFlag : bool
        Overlay a coherent-state distribution with the same mean number.
    fitGaussianFlag : bool
        Overlay a least-squares Gaussian fit.
    mode_plt : int
        Subsystem index kept by ``ptrace`` (also used for the overlays).
    """
    timepoint_list = [-1] if timepoint_list is None else timepoint_list
    skip_handles = [] if skip_handles is None else skip_handles
    title_params = [] if title_params is None else title_params

    def overplot_coherent_dist(timestamp, line_color='blue', label_str='Coherent state'):
        # Overlay the coherent-state distribution with the same <n> as the
        # simulated mode (alpha = sqrt(n_bar)).
        state = data.states[timestamp].ptrace(mode_plt)
        N = state.dims[0][0]
        mean_n = expect(num(N), state)
        print("Mean photon number:", mean_n)
        # Clamp tiny negative round-off so sqrt does not return NaN.
        alpha = np.sqrt(max(float(np.real(mean_n)), 0.0))
        coherent_dist = np.abs(coherent(N, alpha).full()) ** 2
        ax1.plot(np.arange(N), coherent_dist, '--', color=line_color, label=label_str)
        if mean_n > 0:  # avoid an empty (0, 0) x-range
            ax1.set_xlim(0, mean_n + 5 * np.sqrt(mean_n))

    def fit_gaussian_to_distribution(timestamp):
        # Fit A*exp(-(n-mean)^2 / (2 sigma^2)) to the Fock populations at
        # `timestamp` (previously always timepoint_list[0]).
        def gaussian(n, A, mean, sigma):
            return A * np.exp(-0.5 * ((n - mean) / sigma) ** 2)

        state = data.states[timestamp].ptrace(mode_plt)
        photon_dist = np.real(np.diag(state.full()))
        n_values = np.arange(len(photon_dist))

        peak = np.argmax(photon_dist)
        # Width guess sqrt(peak), floored at 1 so a peak at n=0 does not give
        # a zero-width (NaN) starting model.
        initial_guess = [1, peak, max(np.sqrt(peak), 1.0)]
        params, _ = curve_fit(gaussian, n_values, photon_dist, p0=initial_guess)
        A_fit, mean_fit, sigma_fit = params

        ax1.plot(n_values, gaussian(n_values, A_fit, mean_fit, sigma_fit), 'r--', label='Gaussian fit')
        if mean_fit > 0:
            ax1.set_xlim(0, mean_fit + 5 * np.sqrt(mean_fit))
        # No plt.show() here: showing mid-function displayed (and could close)
        # the figure before its title, labels and save were applied.
        ax1.legend()

    if newFigureFlag:
        fig2 = plt.figure()
        ax1 = fig2.add_subplot(111)
    else:
        fig2 = plt.figure(num='phonon_values')  # reused if it already exists
        # Reuse the first axes; a new (or emptied) figure has none yet.
        ax1 = fig2.axes[0] if fig2.axes else fig2.add_subplot(111)

    tlist = np.array(data.times)
    if len(timepoint_list) == 1:
        # Colour follows `iterator` so it matches other plots of this setting.
        timepoint = timepoint_list[0]
        color = default_colors[iterator + 5][1]
        plot_fock_distribution(data.states[timepoint].ptrace(mode_plt), fig=fig2, ax=ax1, color=color)
        if fitCoherentFlag:
            overplot_coherent_dist(timepoint, line_color=color)
        if fitGaussianFlag:
            fit_gaussian_to_distribution(timepoint)
    elif len(timepoint_list) > 1:
        for index, timepoint in enumerate(timepoint_list):
            color = default_colors[index + 5][1]
            plot_fock_distribution(data.states[timepoint].ptrace(mode_plt), fig=fig2, ax=ax1, color=color)
            # Empty line only to create a legend entry for this time point.
            ax1.plot([], [], color=color, label=f'{tlist[timepoint]:.0f} µs')
            if fitCoherentFlag:
                overplot_coherent_dist(timepoint, line_color=color, label_str='')
            if fitGaussianFlag:
                fit_gaussian_to_distribution(timepoint)

    title_string = 'Phonon Distribution. ' + str(create_plot_label(exp_params_dict, title_params))
    fontsize_choice = 12
    ax1.set_title(title_string, fontsize=fontsize_choice)
    ax1.set_xlabel(r'Fock state, $\bar{n}_-$', fontsize=fontsize_choice)
    ax1.set_ylabel('Probability', fontsize=fontsize_choice)

    if savePlotFlag and figure_save_folder_path is not None:
        filename = params_dict_to_filename(exp_params_dict, skip_handles)
        fig_name = os.path.join(figure_save_folder_path, 'phononDistPlot_' + 'mode' + str(mode_plt) + filename + '.png')
        fig2.savefig(fig_name)
        print(f"Plot saved as: {fig_name}")

    plt.draw()
    
